// SPDX-FileCopyrightText: Copyright 2020 yuzu Emulator Project
// SPDX-License-Identifier: GPL-2.0-or-later

#version 430

// Backend-specific setup. Both branches define the same set of macros (HAS_EXTENDED_TYPES,
// BEGIN_PUSH_CONSTANTS, END_PUSH_CONSTANTS, UNIFORM, BINDING_*) so that the rest of the shader
// is written once.
#ifdef VULKAN

// 8-bit and 16-bit storage are required on this path, so the uint8_t/uint16_t buffer views are
// always available.
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_8bit_storage : require
#define HAS_EXTENDED_TYPES 1
// Parameters are declared as members of one push-constant block; UNIFORM(n) expands to nothing.
#define BEGIN_PUSH_CONSTANTS layout(push_constant) uniform PushConstants {
#define END_PUSH_CONSTANTS };
#define UNIFORM(n)
// Every resource gets a distinct binding index on this path.
#define BINDING_SWIZZLE_BUFFER 0
#define BINDING_INPUT_BUFFER 1
#define BINDING_OUTPUT_IMAGE 2

#else // ^^^ Vulkan ^^^ // vvv OpenGL vvv

// The extension is only requested, not required: HAS_EXTENDED_TYPES records whether it is
// actually present, and ReadTexel falls back to 32-bit reads plus bit extraction when it is not.
#extension GL_NV_gpu_shader5 : enable
#ifdef GL_NV_gpu_shader5
#define HAS_EXTENDED_TYPES 1
#else
#define HAS_EXTENDED_TYPES 0
#endif
// No push-constant block here: the BEGIN/END macros expand to nothing and each parameter becomes
// a standalone uniform with an explicit location.
#define BEGIN_PUSH_CONSTANTS
#define END_PUSH_CONSTANTS
#define UNIFORM(n) layout (location = n) uniform
// The output image reuses index 0; in OpenGL image units and shader storage buffer bindings are
// separate binding spaces, so this does not clash with BINDING_SWIZZLE_BUFFER.
#define BINDING_SWIZZLE_BUFFER 0
#define BINDING_INPUT_BUFFER 1
#define BINDING_OUTPUT_IMAGE 0

#endif

// Per-dispatch parameters. Depending on the backend these are members of a push-constant block
// or individual uniforms at locations 0-7 (see the UNIFORM macro).
BEGIN_PUSH_CONSTANTS
// Added to gl_GlobalInvocationID to obtain the position that is read from the input buffer.
UNIFORM(0) uvec3 origin;
// Added to gl_GlobalInvocationID to obtain the coordinate written in the output image.
UNIFORM(1) ivec3 destination;
// Selects the read width in ReadTexel (0..4 -> 1, 2, 4, 8, 16 bytes) and is the left shift
// applied to the x position before the offset is computed.
UNIFORM(2) uint bytes_per_block_log2;
// Multiplied by the z position to form the per-layer part of the offset.
UNIFORM(3) uint layer_stride;
// Multiplied by (GOB row >> block_height) to form the per-block part of the offset.
UNIFORM(4) uint block_size;
// Left shift applied to the GOB column index (x >> GOB_SIZE_X_SHIFT).
UNIFORM(5) uint x_shift;
// Right shift applied to the GOB row index to obtain the block row.
UNIFORM(6) uint block_height;
// Mask applied to the GOB row index to obtain the GOB's row inside its block.
// NOTE(review): presumably (1 << block_height) - 1 -- confirm against the host code.
UNIFORM(7) uint block_height_mask;
END_PUSH_CONSTANTS

// Read-only lookup table consulted by SwizzleOffset. It is indexed as [y * 64 + x] with x masked
// to 0..63 and y masked to 0..7; the fetched value is added to the input offset in main().
layout(binding = BINDING_SWIZZLE_BUFFER, std430) readonly buffer SwizzleTable {
    uint swizzle_table[];
};

// All of the following blocks share BINDING_INPUT_BUFFER, i.e. they are differently typed views
// of the same buffer. ReadTexel picks the view whose element size matches bytes_per_block_log2.
// The 8-bit and 16-bit views only exist when the required types are available.
#if HAS_EXTENDED_TYPES
layout(binding = BINDING_INPUT_BUFFER, std430) buffer InputBufferU8 { uint8_t u8data[]; };
layout(binding = BINDING_INPUT_BUFFER, std430) buffer InputBufferU16 { uint16_t u16data[]; };
#endif
layout(binding = BINDING_INPUT_BUFFER, std430) buffer InputBufferU32 { uint u32data[]; };
// 64-bit and 128-bit elements are represented as vectors of 32-bit words.
layout(binding = BINDING_INPUT_BUFFER, std430) buffer InputBufferU64 { uvec2 u64data[]; };
layout(binding = BINDING_INPUT_BUFFER, std430) buffer InputBufferU128 { uvec4 u128data[]; };

// Destination of the copy: a write-only unsigned-integer 2D array image. main() stores one texel
// per invocation at gl_GlobalInvocationID + destination.
layout(binding = BINDING_OUTPUT_IMAGE) uniform writeonly uimage2DArray output_image;

// 32 x 32 x 1 invocations per workgroup; each invocation performs exactly one read and one
// imageStore.
layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;

// GOB dimensions. The log2 values are the primary definitions; the sizes and the mask below are
// derived from them so the three can never disagree.
const uint GOB_SIZE_X_SHIFT = 6;
const uint GOB_SIZE_Y_SHIFT = 3;
const uint GOB_SIZE_Z_SHIFT = 0;
const uint GOB_SIZE_SHIFT = GOB_SIZE_X_SHIFT + GOB_SIZE_Y_SHIFT + GOB_SIZE_Z_SHIFT;

// Extents: 64 x 8 x 1, 512 in total.
const uint GOB_SIZE_X = 1u << GOB_SIZE_X_SHIFT;
const uint GOB_SIZE_Y = 1u << GOB_SIZE_Y_SHIFT;
const uint GOB_SIZE_Z = 1u << GOB_SIZE_Z_SHIFT;
const uint GOB_SIZE = 1u << GOB_SIZE_SHIFT;

// Keeps only the within-GOB part of an (x, y) position: low 6 bits of x, low 3 bits of y.
const uvec2 SWIZZLE_MASK = uvec2(GOB_SIZE_X - 1u, GOB_SIZE_Y - 1u);

// Returns the swizzle table entry for a position.
// Only the bits selected by SWIZZLE_MASK are used (low 6 bits of pos.x, low 3 bits of pos.y).
// The table is addressed row by row with 64 entries per row.
uint SwizzleOffset(uvec2 pos) {
    const uvec2 masked = pos & SWIZZLE_MASK;
    // Shifting by GOB_SIZE_X_SHIFT (6) is the same as multiplying the row by 64.
    const uint row_start = masked.y << GOB_SIZE_X_SHIFT;
    return swizzle_table[row_start + masked.x];
}

// Reads one element of (1 << bytes_per_block_log2) bytes from the input buffer.
//
// offset: byte offset into the input buffer. It is divided by the element size to index the
//         buffer view of matching width, so any remainder is dropped.
// Returns the element in the low components of a uvec4 with the unused components set to zero.
// A bytes_per_block_log2 outside 0..4 yields uvec4(0).
uvec4 ReadTexel(uint offset) {
    const uint size_log2 = bytes_per_block_log2;

    // 1 byte
    if (size_log2 == 0u) {
#if HAS_EXTENDED_TYPES
        return uvec4(u8data[offset], 0, 0, 0);
#else
        // No 8-bit view: fetch the containing 32-bit word and extract the byte. Byte k of a
        // word sits at bit 8 * k.
        const uint word = u32data[offset / 4];
        const int first_bit = int((offset & 3u) << 3);
        return uvec4(bitfieldExtract(word, first_bit, 8), 0, 0, 0);
#endif
    }

    // 2 bytes
    if (size_log2 == 1u) {
#if HAS_EXTENDED_TYPES
        return uvec4(u16data[offset / 2], 0, 0, 0);
#else
        // No 16-bit view: extract the low or high half of the containing 32-bit word, chosen
        // by bit 1 of the byte offset.
        const uint word = u32data[offset / 4];
        const int first_bit = int((offset & 2u) << 3);
        return uvec4(bitfieldExtract(word, first_bit, 16), 0, 0, 0);
#endif
    }

    // 4 bytes
    if (size_log2 == 2u) {
        return uvec4(u32data[offset / 4], 0, 0, 0);
    }

    // 8 bytes
    if (size_log2 == 3u) {
        return uvec4(u64data[offset / 8], 0, 0);
    }

    // 16 bytes
    if (size_log2 == 4u) {
        return u128data[offset / 16];
    }

    return uvec4(0);
}

// Copies one element per invocation: computes the element's byte offset in the input buffer
// from its position, reads it, and stores it in the output image.
void main() {
    const uvec3 src_pos = gl_GlobalInvocationID + origin;
    // x is scaled by the element size before any offset math; y and z are used as they are.
    const uint scaled_x = src_pos.x << bytes_per_block_log2;

    // Issue the table lookup first: it is a memory read, and the arithmetic below does not
    // depend on its result until the final sum.
    const uint table_offset = SwizzleOffset(uvec2(scaled_x, src_pos.y));

    // Row of the GOB containing this position, then the individual terms of the offset.
    const uint gob_row = src_pos.y >> GOB_SIZE_Y_SHIFT;
    const uint layer_term = src_pos.z * layer_stride;
    const uint block_row_term = (gob_row >> block_height) * block_size;
    const uint gob_in_block_term = (gob_row & block_height_mask) << GOB_SIZE_SHIFT;
    const uint gob_column_term = (scaled_x >> GOB_SIZE_X_SHIFT) << x_shift;

    const uint src_offset =
        layer_term + block_row_term + gob_in_block_term + gob_column_term + table_offset;

    // The destination coordinate is based on the unshifted invocation ID, not on src_pos.
    const ivec3 dst_coord = ivec3(gl_GlobalInvocationID) + destination;
    imageStore(output_image, dst_coord, ReadTexel(src_offset));
}
